-
Notifications
You must be signed in to change notification settings - Fork 129
Allow string keys in eval
utility
#242
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
Conversation
eval
utility
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks great @Raj-Parekh24
I left some suggestions below
While running When I check its logs for basic.py there are no errors from my changes. Is this expected behavior or I have broken something? I am using python 3.11 |
And in the last build, the test case failed for multiple environments is: |
Yeah needs to be investigated. Probably a new release of jax that behaves differently. The tracking issue is #241 |
Mypy doesn't work on python 3.11, issue here #240 |
Thanks for the help @ricardoV94 , I have updated the code and added warning when multiple variables are there for same name. And also updated the test cases accordingly. Please review it. |
pytensor/graph/basic.py
Outdated
warnings.warn( | ||
f"Found {length_of_nodes_with_matching_names} pytensor variables with name {i} taking the first declared named variable for computation" | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems safer to just fail.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of warning will throw Exception instead.
pytensor/graph/basic.py
Outdated
>>> import numpy as np | ||
>>> import pytensor.tensor as at | ||
>>> x = at.dscalar('x') | ||
>>> y = at.dscalar('y') | ||
>>> z = x + y | ||
>>> np.allclose(z.eval({'x' : 3, 'y' : 1}), 4) | ||
True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This does not show a use of the function. In any case we don't need a docstring with an example I think
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we are making it internal to eval I also think that the doc-string will be no more required.
pytensor/graph/basic.py
Outdated
@@ -558,6 +558,46 @@ def get_parents(self): | |||
return [self.owner] | |||
return [] | |||
|
|||
def convert_string_keys_to_pytensor_variables(self, inputs_to_values): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given we are in pytensor
, no need to mention it in the name. If you want you can specify the return type is Dict[Variable, Variable]
def convert_string_keys_to_pytensor_variables(self, inputs_to_values): | |
def convert_string_keys_to_variables(self, inputs_to_values): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sounds good, will replace the function name with suggested one.
tests/graph/test_basic.py
Outdated
assert self.w.eval({self.z: 3}) == 6.0 | ||
|
||
def test_eval_with_strings_with_mulitple_same_name(self): | ||
assert self.t.eval({"e": 1.0}) == 2.0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would create the new variables here instead of setup_method
. It's arguably more readable and these are unlikely to be used in other tests
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok, instead of creating variables in setup_method will add the variables in test_function.
@Raj-Parekh24 nice work. I left some comments above. Let me know what you think. |
Hi @ricardoV94 , Thanks for the review and the suggestions, I have updated the implementation accordingly and pushed a new commit regarding it. Please review it and suggest the required changes. |
Hey @ricardoV94 , I have updated the code bases on your suggestions. And I also want to work on issue #55, I got the requirements but still not able to understand how to implement it. Since I am new to the repo, can you please give me starting steps to work on it. |
Hi @Raj-Parekh24 sorry for the delay. I'll try to review your PR tomorrow. For that other PR I will give some more pointers there Thanks for all the work! |
Thanks for all the help and support @ricardoV94 . |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good. I left some suggestions below
pytensor/graph/basic.py
Outdated
nodes_with_matching_names[ | ||
length_of_nodes_with_matching_names - 1 | ||
] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can just be nodes_with_matching_names[0]
, now that you checked only 1 case exists.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for your review @ricardoV94. Will update the index value.
Thanks a lot for your review @ricardoV94, I have added the code based on what you suggested. Please review it and kindly give your suggestions. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Functionality-wise this PR is ready! I just left some comments on code formatting. Let me know if you have any questions.
tests/graph/test_basic.py
Outdated
assert self.w.eval({"x": 1.0, self.y: 2.0}) == 6.0 | ||
assert self.w.eval({self.z: 3}) == 6.0 | ||
|
||
def test_eval_errors_having_mulitple_variables_same_name(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def test_eval_errors_having_mulitple_variables_same_name(self): | |
def test_eval_with_strings_multiple_matches(self): |
tests/graph/test_basic.py
Outdated
with pytest.raises(Exception, match="Found 2 pytensor variables with name e"): | ||
t.eval({"e": 1}) | ||
|
||
def test_eval_errors_with_no_name_exists(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def test_eval_errors_with_no_name_exists(self): | |
def test_eval_with_strings_no_match(self): |
pytensor/graph/basic.py
Outdated
def convert_string_keys_to_variables(): | ||
process_input_to_values = {} | ||
for i in inputs_to_values: | ||
if isinstance(i, str): | ||
nodes_with_matching_names = get_var_by_name([self], i) | ||
length_of_nodes_with_matching_names = len(nodes_with_matching_names) | ||
if length_of_nodes_with_matching_names == 0: | ||
raise Exception(f"{i} not found in graph") | ||
else: | ||
if length_of_nodes_with_matching_names > 1: | ||
raise Exception( | ||
f"Found {length_of_nodes_with_matching_names} pytensor variables with name {i}" | ||
) | ||
process_input_to_values[ | ||
nodes_with_matching_names[0] | ||
] = inputs_to_values[i] | ||
else: | ||
process_input_to_values[i] = inputs_to_values[i] | ||
return process_input_to_values | ||
|
||
inputs_to_values = convert_string_keys_to_variables() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should be a bit more "pythonic" like this (I did not test the code!):
def convert_string_keys_to_variables(): | |
process_input_to_values = {} | |
for i in inputs_to_values: | |
if isinstance(i, str): | |
nodes_with_matching_names = get_var_by_name([self], i) | |
length_of_nodes_with_matching_names = len(nodes_with_matching_names) | |
if length_of_nodes_with_matching_names == 0: | |
raise Exception(f"{i} not found in graph") | |
else: | |
if length_of_nodes_with_matching_names > 1: | |
raise Exception( | |
f"Found {length_of_nodes_with_matching_names} pytensor variables with name {i}" | |
) | |
process_input_to_values[ | |
nodes_with_matching_names[0] | |
] = inputs_to_values[i] | |
else: | |
process_input_to_values[i] = inputs_to_values[i] | |
return process_input_to_values | |
inputs_to_values = convert_string_keys_to_variables() | |
def convert_string_keys_to_variables(input_to_values): | |
new_input_to_values = {} | |
for key, value in inputs_to_values.items(): | |
if isinstance(key, str): | |
matching_vars = get_var_by_name([self], key) | |
if not matching_vars: | |
raise Exception(f"{key} not found in graph") | |
elif len(matching_vars) > 1: | |
raise Exception( | |
f"Found multiple variables with name {key}" | |
) | |
input_to_values[matching_vars[0]] = value | |
else: | |
new_input_to_values[key] = value | |
return new_input_to_values | |
inputs_to_values = convert_string_keys_to_variables(inputs_to_values) |
Thanks for the review @ricardoV94 , I have incorporated the suggested changes in the new commit, please review it. I also wanted to implement a feature under issue #55, the initial points can help me to make things clear. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, running the tests now!
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #242 +/- ##
=======================================
Coverage 80.44% 80.44%
=======================================
Files 170 170
Lines 45328 45341 +13
Branches 11069 11073 +4
=======================================
+ Hits 36463 36476 +13
Misses 6642 6642
Partials 2223 2223
|
Motivation for these changes
Closes #227
Implementation details
I have implemented a function that converts the dictionary with string keys to pytensor. The pytensor variable for the corresponding name is found using the function "get_var_by_name".
Checklist
Major / Breaking Changes
New features
Bugfixes
Documentation
Maintenance